{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU",
  "gpuClass": "standard"
 },
 "cells": [
  {
   "cell_type": "code",
   "source": [
    "# Uncomment the line below to install the ydata-synthetic library\n",
    "# !pip install ydata-synthetic"
   ],
   "metadata": {
    "id": "fwXSWiYu_tl0",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Tabular Synthetic Data Generation with Gaussian Mixture\n",
    "- This notebook shows how to use a synthetic data generation method based on a Gaussian Mixture Model ([GMM](https://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html)) to generate synthetic tabular data with numeric and categorical features.\n",
    "\n",
    "## Dataset\n",
    "- The data used is the [Adult Census Income](https://www.kaggle.com/datasets/uciml/adult-census-income) dataset, which we fetch with the `pmlb` library (a Python wrapper for the Penn Machine Learning Benchmarks data repository).\n"
   ],
   "metadata": {
    "id": "6T8gjToi_yKA",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "from pmlb import fetch_data\n",
    "\n",
    "from ydata_synthetic.synthesizers.regular import RegularSynthesizer\n",
    "from ydata_synthetic.synthesizers import ModelParameters, TrainParameters"
   ],
   "metadata": {
    "id": "Ix4gZ9iSCVZI",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Load the data"
   ],
   "metadata": {
    "id": "I0qyPwoECZ5x",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Load the Adult dataset from PMLB as a pandas DataFrame\n",
    "data = fetch_data('adult')\n",
    "\n",
    "# Columns to treat as numerical and as categorical, respectively\n",
    "num_cols = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']\n",
    "cat_cols = ['workclass', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',\n",
    "            'native-country', 'target']"
   ],
   "metadata": {
    "id": "YeFPnJVOMVqd",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 2,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Create and train the synthetic data generator"
   ],
   "metadata": {
    "id": "68MoepO0Cpx6",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
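  {
   "cell_type": "code",
   "source": [
    "# modelname='fast' selects the Gaussian Mixture based synthesizer\n",
    "synth = RegularSynthesizer(modelname='fast')\n",
    "# Fit on the data, declaring which columns are numerical and which are categorical\n",
    "synth.fit(data=data, num_cols=num_cols, cat_cols=cat_cols)"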
   ],
   "metadata": {
    "id": "oIHMVgSZMg8_",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Generate new synthetic data"
   ],
   "metadata": {
    "id": "xHK-SRPyDUin",
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Draw 1000 synthetic rows from the fitted synthesizer\n",
    "synth_data = synth.sample(1000)\n",
    "print(synth_data)"
   ],
   "metadata": {
    "id": "0aa2g0RLMkqe",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "01808aa4-a700-4385-e7df-b2f7abd162a0",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": 8,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "           age  workclass         fnlwgt  education  education-num  \\\n",
      "0    38.753654          4  179993.565472          8           10.0   \n",
      "1    36.408844          4  245841.807958          9           10.0   \n",
      "2    56.251066          4  400895.076058         11           13.0   \n",
      "3    26.846605          4  240156.201048         11           10.0   \n",
      "4    29.083102          1    5601.059126         11            9.0   \n",
      "..         ...        ...            ...        ...            ...   \n",
      "995  79.281276          4   30664.183560          1           10.0   \n",
      "996  51.423132          4  414524.980527          1           10.0   \n",
      "997  17.342915          6  177716.451926         11           13.0   \n",
      "998  39.298867          4  132011.369567         15           12.0   \n",
      "999  46.977763          2   92662.371635          9           13.0   \n",
      "\n",
      "     marital-status  occupation  relationship  race  sex  capital-gain  \\\n",
      "0                 4           0             3     4    0     55.771499   \n",
      "1                 6           7             0     4    1    124.337939   \n",
      "2                 4           3             3     4    1     27.968087   \n",
      "3                 4           6             1     4    0     25.065678   \n",
      "4                 6           3             0     4    0    126.269337   \n",
      "..              ...         ...           ...   ...  ...           ...   \n",
      "995               2           0             3     4    1      4.393001   \n",
      "996               4           7             3     2    0     54.841598   \n",
      "997               4           4             4     4    0     99.394428   \n",
      "998               4          14             1     4    1     97.834797   \n",
      "999               4           8             1     4    0     51.258308   \n",
      "\n",
      "     capital-loss  hours-per-week  native-country  target  \n",
      "0       -1.271118       39.749641              39       1  \n",
      "1       -2.114950       44.488198              39       1  \n",
      "2        1.541738       40.042696              39       1  \n",
      "3        1.148560       39.952615              39       1  \n",
      "4       -1.786768       39.808085              39       0  \n",
      "..            ...             ...             ...     ...  \n",
      "995      0.224015       50.580637              39       1  \n",
      "996      1.319341        4.441194              39       1  \n",
      "997     -5.231663       39.779674              39       1  \n",
      "998      1.595817       39.731359              13       1  \n",
      "999      1.129814       39.838415              39       1  \n",
      "\n",
      "[1000 rows x 15 columns]\n"
     ]
    }
   ]
  }
 ]
}